package shell

import (
	"bufio"
	"errors"
	"gosh/config"
	"gosh/pkg/model"
	"io"
	"io/ioutil"
	"os"
	"path"
	"regexp"
	"strings"
	"time"
)

var (
	// SSHConfig is the package-wide Config instance (presumably populated from
	// command-line flags elsewhere — not visible in this file).
	SSHConfig = &Config{}

	// bastionMap holds the resolved bastion settings, keyed by bastion address.
	bastionMap = map[string]*Bastion{}

	// hostMap holds the resolved target-host settings, keyed by host address.
	hostMap = map[string]*Host{}
)

// Config holds the connection options for a run (NOTE(review): presumably
// bound to command-line flags elsewhere — confirm). Parse turns it into
// per-host entries in hostMap and bastionMap.
type Config struct {
	// Inventory is the path of a file listing one IPv4 address per line;
	// other lines are ignored.
	Inventory            string
	// Host is a comma-separated list of target addresses.
	Host                 string
	// Hosts is the combined address list; Parse overwrites it.
	Hosts                []string
	// Port overrides each host's SSH port when > 0.
	Port                 int
	// User, Password, Key (private key file path) and KeyPassphrase override
	// the per-host values when non-empty.
	User                 string
	Password             string
	Key                  string
	KeyPassphrase        string
	// The Bastion* fields are the same options for the bastion (jump host).
	// BastionHost, when set, replaces the bastion recorded for a host in the
	// database.
	BastionHost          string
	BastionPort          int
	BastionUser          string
	BastionPassword      string
	BastionKey           string
	BastionKeyPassphrase string
	// Fork is not read in this file; presumably the number of hosts handled
	// concurrently — TODO confirm.
	Fork                 int
	// ConnectTimeout is copied into every host's and bastion's SSHTimeout.
	ConnectTimeout       time.Duration
	// PrintFields and PrintStatus are not read in this file; presumably
	// output-selection options — TODO confirm.
	PrintFields          []string
	PrintStatus          []string
	// Labels are "key=value" strings that labelToCondition copies into
	// Condition.
	Labels               []string
	// Condition is the filter passed to the host database query; Parse sets
	// its "connect_ip" key for each host in turn.
	Condition            map[string]string
	// Args is not read in this file; presumably the remote command and its
	// arguments — TODO confirm.
	Args                 []string
}

// labelToCondition copies every "key=value" entry of c.Labels into
// c.Condition, allocating the map when it is nil. Only the first "=" separates
// key from value, so "k=a=b" maps k to "a=b". Entries without "=", with an
// empty key, or with an empty value are skipped.
func (c *Config) labelToCondition() {
	if c.Condition == nil {
		c.Condition = make(map[string]string, len(c.Labels))
	}
	for _, label := range c.Labels {
		kv := strings.SplitN(label, "=", 2)
		if len(kv) != 2 || kv[0] == "" || kv[1] == "" {
			continue
		}
		c.Condition[kv[0]] = kv[1]
	}
}

// Parse resolves the SSH connection settings of every target host and records
// them in the package-level hostMap (keyed by host address) and bastionMap
// (keyed by bastion address). Neither map is cleared first.
//
// Target addresses come from c.Host (comma separated) followed by the IPv4
// addresses listed in the c.Inventory file. For every setting, a non-empty
// (or, for ports, positive) value on c wins. Otherwise the database record is
// used, and only when config.GlobalConfig.UsedDB is set. Otherwise a default
// applies ($USER, port 22, $HOME/.ssh/id_rsa).
//
// An error is returned when no address is found, when reading the inventory
// file or a database lookup for a host fails, or when an SSH key file cannot
// be read.
func (c *Config) Parse() error {
	// Collect the target host addresses.
	c.Hosts = make([]string, 0, 10)
	for _, addr := range strings.Split(c.Host, ",") {
		// Skip blanks so that "a, b", "a,,b" and a trailing comma do not
		// produce empty or padded addresses.
		if addr = strings.TrimSpace(addr); addr != "" {
			c.Hosts = append(c.Hosts, addr)
		}
	}
	if c.Inventory != "" {
		list, err := readIPList(c.Inventory)
		if err != nil {
			return err
		}
		c.Hosts = append(c.Hosts, list...)
	}
	if len(c.Hosts) == 0 {
		return errors.New("not found any remote host address")
	}
	c.Condition = make(map[string]string)
	c.labelToCondition()

	// A bastion given on the command line is the same for every host, so it
	// is looked up only once. As in getHostFromDB, the database is consulted
	// only when it is enabled. A failed lookup is deliberately ignored: the
	// command-line values and defaults can still describe the bastion.
	var cliBastion *model.Bastion
	if c.BastionHost != "" && config.GlobalConfig.UsedDB {
		cliBastion, _ = model.QueryBastionByCondition(map[string]string{"connect_ip": c.BastionHost})
	}

	for _, connectIP := range c.Hosts {
		c.Condition["connect_ip"] = connectIP
		dbHost, dbBastion, err := c.getHostFromDB()
		if err != nil {
			return err
		}

		// Handle the target machine.
		host, err := c.resolveHost(connectIP, dbHost)
		if err != nil {
			return err
		}
		hostMap[connectIP] = host

		// Handle the bastion (jump host), if any.
		var bastionIP string
		switch {
		case c.BastionHost != "":
			// An explicit bastion replaces the one recorded for this host, so
			// that record's credentials must not be applied to it.
			bastionIP, dbBastion = c.BastionHost, cliBastion
		case dbBastion != nil:
			bastionIP = dbBastion.ConnectIP
		default:
			continue // direct connection
		}
		host.BastionIP = bastionIP
		if _, exist := bastionMap[bastionIP]; exist {
			continue // already resolved for an earlier host
		}
		bastion, err := c.resolveBastion(bastionIP, dbBastion)
		if err != nil {
			return err
		}
		bastionMap[bastionIP] = bastion
	}

	return nil
}

// resolveHost builds the settings for connectIP. Each field comes from c when
// set, otherwise from dbHost (which may be nil), otherwise from a default.
// The key file the result names is then read into SSHKeyContent.
func (c *Config) resolveHost(connectIP string, dbHost *model.Host) (*Host, error) {
	host := &Host{ConnectIP: connectIP, SSHTimeout: c.ConnectTimeout}

	switch {
	case c.User != "":
		host.SSHUser = c.User
	case dbHost != nil && dbHost.SSHUser != "":
		host.SSHUser = dbHost.SSHUser
	default:
		host.SSHUser = os.Getenv("USER")
	}

	switch {
	case c.Port > 0:
		host.SSHPort = c.Port
	case dbHost != nil && dbHost.SSHPort > 0:
		host.SSHPort = dbHost.SSHPort
	default:
		host.SSHPort = 22
	}

	switch {
	case c.Password != "":
		host.SSHPassword = c.Password
	case dbHost != nil:
		host.SSHPassword = dbHost.SSHPassword
	}

	switch {
	case c.Key != "":
		host.SSHKeyFile = c.Key
	case dbHost != nil && dbHost.SSHKeyFile != "":
		host.SSHKeyFile = replaceHomeDir(dbHost.SSHKeyFile)
	default:
		host.SSHKeyFile = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	switch {
	case c.KeyPassphrase != "":
		host.SSHKeyPassphrase = c.KeyPassphrase
	case dbHost != nil && dbHost.SSHKeyPassphrase != "":
		host.SSHKeyPassphrase = dbHost.SSHKeyPassphrase
	}

	key, err := readSSHKey(host.SSHKeyFile)
	if err != nil {
		return nil, err
	}
	host.SSHKeyContent = key
	return host, nil
}

// resolveBastion builds the settings for the bastion at ip. It uses the
// Bastion* options on c, then dbBastion (which may be nil), then defaults,
// and reads the key file the result names into SSHKeyContent.
func (c *Config) resolveBastion(ip string, dbBastion *model.Bastion) (*Bastion, error) {
	bastion := &Bastion{IP: ip, SSHTimeout: c.ConnectTimeout}

	switch {
	case c.BastionUser != "":
		bastion.SSHUser = c.BastionUser
	case dbBastion != nil && dbBastion.SSHUser != "":
		bastion.SSHUser = dbBastion.SSHUser
	default:
		bastion.SSHUser = os.Getenv("USER")
	}

	switch {
	case c.BastionPort > 0:
		bastion.SSHPort = c.BastionPort
	case dbBastion != nil && dbBastion.SSHPort > 0:
		bastion.SSHPort = dbBastion.SSHPort
	default:
		bastion.SSHPort = 22
	}

	switch {
	case c.BastionPassword != "":
		bastion.SSHPassword = c.BastionPassword
	case dbBastion != nil:
		bastion.SSHPassword = dbBastion.SSHPassword
	}

	switch {
	case c.BastionKey != "":
		bastion.SSHKeyFile = c.BastionKey
	case dbBastion != nil && dbBastion.SSHKeyFile != "":
		bastion.SSHKeyFile = replaceHomeDir(dbBastion.SSHKeyFile)
	default:
		bastion.SSHKeyFile = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	switch {
	case c.BastionKeyPassphrase != "":
		bastion.SSHKeyPassphrase = c.BastionKeyPassphrase
	case dbBastion != nil && dbBastion.SSHKeyPassphrase != "":
		bastion.SSHKeyPassphrase = dbBastion.SSHKeyPassphrase
	}

	key, err := readSSHKey(bastion.SSHKeyFile)
	if err != nil {
		return nil, err
	}
	bastion.SSHKeyContent = key
	return bastion, nil
}

// getHostFromDB looks up the host matching c.Condition in the database and,
// when that host records a bastion address, the bastion as well.
// It returns all nils when the database is disabled. If the host query fails
// or yields no host, nil records are returned with the query's error, which
// may be nil. If the bastion query fails, the host is discarded and that
// error is returned.
func (c *Config) getHostFromDB() (*model.Host, *model.Bastion, error) {
	if config.GlobalConfig.UsedDB {
		dbHost, err := model.QueryHostByCondition(c.Condition)
		switch {
		case err != nil || dbHost == nil:
			return nil, nil, err
		case dbHost.BastionIP == "":
			return dbHost, nil, nil
		}

		dbBastion, err := model.QueryBastionByCondition(map[string]string{"connect_ip": dbHost.BastionIP})
		if err != nil {
			return nil, nil, err
		}
		return dbHost, dbBastion, nil
	}
	return nil, nil, nil
}

// inventoryIPPattern matches a line holding a single dotted-quad address,
// optionally padded with whitespace. Octet values are not range-checked.
var inventoryIPPattern = regexp.MustCompile(`^\s*([0-9]{1,3}\.){3}[0-9]{1,3}\s*$`)

// readIPList returns, in file order, the IPv4 addresses listed one per line in
// file, with surrounding whitespace removed. Lines that are not a lone address
// (blank lines, comments, hostnames) are ignored. It returns an error if the
// file cannot be opened or read.
func readIPList(file string) ([]string, error) {
	fd, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	defer fd.Close()

	res := make([]string, 0, 10)
	reader := bufio.NewReader(fd)
	for {
		// ReadString always yields whole lines, whatever their length.
		// ReadLine on a small buffer would split long lines into fragments,
		// and each fragment could be mistaken for a line of its own.
		line, readErr := reader.ReadString('\n')
		if inventoryIPPattern.MatchString(line) {
			res = append(res, strings.TrimSpace(line))
		}
		if readErr == io.EOF {
			return res, nil
		}
		if readErr != nil {
			return nil, readErr
		}
	}
}

// readSSHKey returns the entire contents of the key file at the given path,
// or "" and the read error.
func readSSHKey(file string) (string, error) {
	data, err := ioutil.ReadFile(file)
	if err == nil {
		return string(data), nil
	}
	return "", err
}

// replaceHomeDir expands a leading "~" (alone or followed by "/") to $HOME.
// Other paths, including "~user/..." (another user's home, which $HOME cannot
// describe), are returned unchanged.
func replaceHomeDir(file string) string {
	if file == "~" || strings.HasPrefix(file, "~/") {
		return os.Getenv("HOME") + file[1:]
	}
	return file
}
